{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "03c6be7c",
   "metadata": {},
   "source": [
    "### 代码实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3ecb9916",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入必要的库，torchinfo用于查看模型结构\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchinfo import summary"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a980a08",
   "metadata": {},
   "source": [
    "### 结构定义"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "72e3dcde",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LeNet-style CNN with batch normalization after each hidden layer\n",
    "class LeNet(nn.Module):\n",
    "    \"\"\"LeNet-style convolutional classifier with batch normalization.\n",
    "\n",
    "    Two 5x5 convolution blocks (each followed by ReLU and 2x2 max pooling)\n",
    "    feed three fully connected layers that produce 10 output scores. The\n",
    "    first fully connected layer expects a 16x4x4 feature map, which is what\n",
    "    a 1x28x28 input produces.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Conv block 1: 1 input channel -> 6 output channels, 5x5 kernel, then BN\n",
    "        self.conv1 = nn.Sequential(\n",
    "            nn.Conv2d(1, 6, 5),\n",
    "            nn.BatchNorm2d(6)\n",
    "        )\n",
    "        # Conv block 2: 6 input channels -> 16 output channels, 5x5 kernel, then BN\n",
    "        self.conv2 = nn.Sequential(\n",
    "            nn.Conv2d(6, 16, 5),\n",
    "            nn.BatchNorm2d(16)\n",
    "        )\n",
    "        # FC block 1: 16*4*4 = 256 inputs -> 120 outputs; the input width was\n",
    "        # adjusted to 16*4*4 because the input data differs slightly in size\n",
    "        self.fc1 = nn.Sequential(\n",
    "            nn.Linear(16 * 4 * 4, 120),\n",
    "            nn.BatchNorm1d(120)\n",
    "        )\n",
    "        # FC block 2: 120 inputs -> 84 outputs\n",
    "        self.fc2 = nn.Sequential(\n",
    "            nn.Linear(120, 84),\n",
    "            nn.BatchNorm1d(84)\n",
    "        )\n",
    "        # Output layer: 84 inputs -> 10 outputs\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Compute the 10 output scores for a batch of images.\n",
    "\n",
    "        Args:\n",
    "            x: Input tensor of shape (batch, 1, height, width); the layer\n",
    "                sizes are chosen for 28x28 images.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (batch, 10) with the raw outputs of the last\n",
    "            linear layer (no softmax is applied).\n",
    "\n",
    "        Raises:\n",
    "            RuntimeError: If the spatial size of ``x`` does not lead to a\n",
    "                16x4x4 feature map; the mismatch is reported by ``fc1``.\n",
    "        \"\"\"\n",
    "        # Block 1: convolution + BN, ReLU, then 2x2 max pooling\n",
    "        x = torch.relu(self.conv1(x))\n",
    "        x = nn.functional.max_pool2d(x, 2)\n",
    "        # Block 2: convolution + BN, ReLU, then 2x2 max pooling\n",
    "        x = torch.relu(self.conv2(x))\n",
    "        x = nn.functional.max_pool2d(x, 2)\n",
    "        # Flatten everything except the batch dimension. Unlike view(-1, 256),\n",
    "        # this never changes the number of rows, so a wrongly sized input\n",
    "        # fails in fc1 instead of being reshaped into a different batch size.\n",
    "        x = torch.flatten(x, 1)\n",
    "        # Fully connected layers with ReLU\n",
    "        x = torch.relu(self.fc1(x))\n",
    "        x = torch.relu(self.fc2(x))\n",
    "        # Output layer\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05abf6fa",
   "metadata": {},
   "source": [
    "### 网络结构"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e1529d62",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "==========================================================================================\n",
       "Layer (type:depth-idx)                   Output Shape              Param #\n",
       "==========================================================================================\n",
       "LeNet                                    [1, 10]                   --\n",
       "├─Sequential: 1-1                        [1, 6, 24, 24]            --\n",
       "│    └─Conv2d: 2-1                       [1, 6, 24, 24]            156\n",
       "│    └─BatchNorm2d: 2-2                  [1, 6, 24, 24]            12\n",
       "├─Sequential: 1-2                        [1, 16, 8, 8]             --\n",
       "│    └─Conv2d: 2-3                       [1, 16, 8, 8]             2,416\n",
       "│    └─BatchNorm2d: 2-4                  [1, 16, 8, 8]             32\n",
       "├─Sequential: 1-3                        [1, 120]                  --\n",
       "│    └─Linear: 2-5                       [1, 120]                  30,840\n",
       "│    └─BatchNorm1d: 2-6                  [1, 120]                  240\n",
       "├─Sequential: 1-4                        [1, 84]                   --\n",
       "│    └─Linear: 2-7                       [1, 84]                   10,164\n",
       "│    └─BatchNorm1d: 2-8                  [1, 84]                   168\n",
       "├─Linear: 1-5                            [1, 10]                   850\n",
       "==========================================================================================\n",
       "Total params: 44,878\n",
       "Trainable params: 44,878\n",
       "Non-trainable params: 0\n",
       "Total mult-adds (M): 0.29\n",
       "==========================================================================================\n",
       "Input size (MB): 0.00\n",
       "Forward/backward pass size (MB): 0.08\n",
       "Params size (MB): 0.18\n",
       "Estimated Total Size (MB): 0.26\n",
       "=========================================================================================="
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the model structure and parameter counts; input_size gives the\n",
    "# dimensions of an example input\n",
    "summary(LeNet(), input_size=(1, 1, 28, 28))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ad16a80",
   "metadata": {},
   "source": [
    "### 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "96bf0385",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0 Loss: 2.2088656986293165 Acc: 0.9578\n",
      "Epoch: 2 Loss: 1.4376559603001913 Acc: 0.979  \n",
      "Epoch: 4 Loss: 1.228520721191793 Acc: 0.9803  \n",
      "Epoch: 6 Loss: 1.106042682456222 Acc: 0.9859  \n",
      "Epoch: 8 Loss: 1.0158490855052476 Acc: 0.9883 \n",
      "100%|██████████| 10/10 [01:56<00:00, 11.62s/it]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8+yak3AAAACXBIWXMAAAsTAAALEwEAmpwYAAAm20lEQVR4nO3deXxU9b3/8dcnM1kI2UlIyAJBRBYDSSQooEWsdetVccG21rpQFfnV9tfd1t7e2/bW+6u/trf7Yql1q9XqT7S17lK1qIAakH1RRAIJECBkYwlk+f7+mEkISEgIk5yZyfv5eMwjM+d8c86Hgbz55nu+5zvmnENERCJfjNcFiIhIaCjQRUSihAJdRCRKKNBFRKKEAl1EJEr4vTpxZmamKyws9Or0IiIRaenSpbudc1nH2udZoBcWFlJeXu7V6UVEIpKZVXS1T0MuIiJRQoEuIhIlFOgiIlHCszF0EYluzc3NVFZW0tTU5HUpESkhIYH8/HxiY2N7/D0KdBHpE5WVlSQnJ1NYWIiZeV1ORHHOUVNTQ2VlJSNHjuzx92nIRUT6RFNTE0OGDFGY94KZMWTIkBP+7abbQDezAjN71czWmtkaM/vyMdpcZ2YrzWyVmS0ys+ITqkJEopLCvPd68971pIfeAnzdOTcemALcbmbjj2rzIXCuc24C8ENg3glX0kMbdzbyg3+s4VBLW1+dQkQkInUb6M657c65ZcHnjcA6IO+oNoucc7XBl0uA/FAX2m7rngPc/+ZmXlm/s69OISJRIikpyesS+tUJjaGbWSFQCrx1nGY3A8+fRE3H9bHRmWQlx/PE0sq+OoWISETqcaCbWRIwH/iKc66hizbnEQj0b3Wxf46ZlZtZ+a5du3pTL35fDFeV5vHahp3s3nuwV8cQkYHFOcc3v/lNioqKmDBhAo899hgA27dvZ/r06ZSUlFBUVMTrr79Oa2srN910U0fbn//85x5X33M9mrZoZrEEwvwvzrknu2gzEbgXuMQ5V3OsNs65eQTH18vKynr92XdXT8rnDws38ffl27j5nJ5P6RERb/zgH2tYu+2Y/cBeG5+bwvcuO71HbZ988kmWL1/OihUr2L17N5MnT2b69Ok88sgjXHTRRfz7v/87ra2t7N+/n+XLl1NVVcXq1asBqKurC2ndfakns1wM+BOwzjn3sy7aDAeeBK53zr0X2hI/6rTsZCbmp2rYRUR65I033uDaa6/F5/ORnZ3NueeeyzvvvMPkyZO5//77+f73v8+qVatITk7mlFNOYdOmTXzpS1/ihRdeICUlxevye6wnPfSzgeuBVWa2PLjtO8BwAOfcPcB/AkOA3wWn2rQ458pCXm0nsybl859/X8OabfWcnpval6cSkZPU0550f5s+fToLFy7k2Wef5aabbuJrX/saN9xwAytWrODFF1/knnvu4fHHH+e+++7zutQe6ckslzecc+acm+icKwk+nnPO3RMMc5xztzjn0jvt79MwB7hsYi5xvhj10kWkWx/72Md47LHHaG1tZdeuXSxcuJAzzzyTiooKsrOzufXWW7nllltYtmwZu3fvpq2tjauvvpq77rqLZcuWeV1+j0Xsrf/pg+P4xPih/H35Nu68ZBxxft30KiLHduWVV7J48WKKi4sxM3784x+Tk5PDgw8+yE9+8hNiY2NJSkrioYceoqqqitmzZ9PWFrjX5Uc/+pHH1fecOdfra5MnpayszJ3sB1y8sr6azz9QzrzrJ3Hh6TkhqkxEQmHdunWMGzfO6zIi2rHeQzNb2tUoSER3a6ePziIzSXPSRUQgwgPd74vhytJcXlm/kxrNSReRAS6iAx0Cc9Jb2hx/X77N61JERDwV8YE+NieFCXmpzF+mYRcRGdgiPtAhMCd9zbaGkN+JJiISSaIi0C8vziXWZ+qli8iAFhWBnj44jvPHZvO3d6tobtU66SIyMEVFoENg2KVm3yFe29C7VRxFRHqrpaXF6xKAKAr0
c8dkkZkUxxNLt3pdioiEkSuuuIJJkyZx+umnM29e4MPUXnjhBc444wyKi4s5//zzAdi7dy+zZ89mwoQJTJw4kfnz5wNHfkjGE088wU033QTATTfdxNy5cznrrLO44447ePvtt5k6dSqlpaVMmzaNDRs2ANDa2so3vvENioqKmDhxIr/+9a955ZVXuOKKKzqO+/LLL3PllVee9J81Ym/9P1qsL4YrSvJ4cPFm9uw7RMbgOK9LEpF2z38bdqwK7TFzJsAld3fb7L777iMjI4MDBw4wefJkZs6cya233srChQsZOXIke/bsAeCHP/whqamprFoVqLO2tvZ4hwWgsrKSRYsW4fP5aGho4PXXX8fv97NgwQK+853vMH/+fObNm8fmzZtZvnw5fr+fPXv2kJ6ezhe+8AV27dpFVlYW999/P5///OdP7v0ginroEJiT3tzqeHp5ldeliEiY+NWvfkVxcTFTpkxh69atzJs3j+nTpzNyZOCzFDIyMgBYsGABt99+e8f3paend3vsa665Bp/PB0B9fT3XXHMNRUVFfPWrX2XNmjUdx73tttvw+/0d5zMzrr/+eh5++GHq6upYvHgxl1xyyUn/WaOmhw4wblgKRXkpPLGskpvO1gdfiISNHvSk+8Jrr73GggULWLx4MYmJicyYMYOSkhLWr1/f42MElwQHoKmp6Yh9gwcP7nj+H//xH5x33nk89dRTbN68mRkzZhz3uLNnz+ayyy4jISGBa665piPwT0ZU9dABrj4jn9VVDazbrjnpIgNdfX096enpJCYmsn79epYsWUJTUxMLFy7kww8/BOgYcrngggv47W9/2/G97UMu2dnZrFu3jra2Np566qnjnisvLw+ABx54oGP7BRdcwB/+8IeOC6ft58vNzSU3N5e77rqL2bNnh+TPG3WBPrMkLzAnXQt2iQx4F198MS0tLYwbN45vf/vbTJkyhaysLObNm8dVV11FcXExn/70pwH47ne/S21tLUVFRRQXF/Pqq68CcPfdd3PppZcybdo0hg0b1uW57rjjDu68805KS0uPmPVyyy23MHz4cCZOnEhxcTGPPPJIx77rrruOgoKCkK1KGdHL53bltj+Xs7SijsV3fpxYX9T9nyUSEbR8bve++MUvUlpays0333zM/QNq+dyuzJpUwO69B1n4nuaki0h4mjRpEitXruRzn/tcyI4ZVRdF280Yk8WQwXE8sbSS88dle12OiMhHLF26NOTHjMoeeqwvhitK81iwrprafYe8LkdkwPJqSDca9Oa9i8pAh8Bsl+ZWx9MrtE66iBcSEhKoqalRqPeCc46amhoSEhJO6PuicsgFYHxuCuOHpTB/WSU3Tiv0uhyRASc/P5/Kykp27dK1rN5ISEggPz//hL4nagMdAgt2/dcza9mwo5ExOclelyMyoMTGxnbcjSn9o9shFzMrMLNXzWytma0xsy8fo42Z2a/MbKOZrTSzM/qm3BMzsyQXf4zWSReRgaEnY+gtwNedc+OBKcDtZjb+qDaXAKODjznA70NaZS8NSYrn42OH8uSyKlq0TrqIRLluA905t905tyz4vBFYB+Qd1Wwm8JALWAKkmVnXt1T1o6sn5QfmpL+vcTwRiW4nNMvFzAqBUuCto3blAZ0XIq/ko6GPmc0xs3IzK++vCyXnjRlKRnBOuohINOtxoJtZEjAf+IpzrlcrXznn5jnnypxzZVlZWb05xAmL88cwsySXBWt3Urdfc9JFJHr1KNDNLJZAmP/FOffkMZpUAQWdXucHt4WFWZPyOdTaxj80J11EolhPZrkY8CdgnXPuZ100exq4ITjbZQpQ75zbHsI6T8rpuamMG5aiYRcRiWo96aGfDVwPfNzMlgcfnzSzuWY2N9jmOWATsBH4I/CFvim392ZNymdFZT3vVTd6XYqISJ/o9sYi59wbgHXTxgG3H6+N12aW5PKj59Yxf2kld35SS3qKSPSJ2rVcjpaZFM+MMUN56l3NSReR6DRgAh0C
wy47Gw/y+sbdXpciIhJyAyrQPz52KOmJsbo4KiJRaUAFemBOeh4vr6mmfn+z1+WIiITUgAp0ODwn/emVmpMuItFlwAX66bkpjM1JZr6GXUQkygy4QDczZk3KZ/nWOjbu1Jx0EYkeAy7QAWaW5OGLMZ5YGjarE4iInLQBGehZyfGcNyaLp96tpLVNn3coItFhQAY6BC6OVjcc5HWtky4iUWLABvp5Y4eSpjnpIhJFBmygx/t9zCzO5aW11dQf0Jx0EYl8AzbQAWZNKuBQSxvPaE66iESBAR3oRXkpjMlO1rCLiESFAR3o7XPS391Sx8ade70uR0TkpAzoQAeYWZqLL8aYv0y9dBGJbAM+0IcmJ3DuaVk8taxKc9JFJKIN+ECHwJz0HQ1NvKl10kUkginQgfPHDSV1kOaki0hkU6ATnJNeksuLa3ZoTrqIRCwFetCsSfkcbGnj2ZXbvS5FRKRXFOhBE/JSOS07SbNdRCRiKdCDzIyrz8hnaUUtm3ZpTrqIRJ5uA93M7jOznWa2uov9qWb2DzNbYWZrzGx26MvsH1eW5hFjqJcuIhGpJz30B4CLj7P/dmCtc64YmAH8j5nFnXxp/W9oSmBO+pOaky4iEajbQHfOLQT2HK8JkGxmBiQF27aEprz+N2tSAdvrm1j0geaki0hkCcUY+m+AccA2YBXwZedc27EamtkcMys3s/Jdu8LzgyXa56TrQ6RFJNKEItAvApYDuUAJ8BszSzlWQ+fcPOdcmXOuLCsrKwSnDr2EWB+XF+fywpodNDRpTrqIRI5QBPps4EkXsBH4EBgbguN65upJ+TQ1t/Gc5qSLSAQJRaBvAc4HMLNsYAywKQTH9UxxfiqnDk3SUgAiElF6Mm3xUWAxMMbMKs3sZjOba2Zzg01+CEwzs1XAP4FvOeci+opi+zrp5RW1fLh7n9fliIj0iL+7Bs65a7vZvw24MGQVhYkrS/P48Qvrmb+0km9cNMbrckREuqU7RbuQnZLA9NOyeHJZJW2aky4iEUCBfhxXn5HPtvomFm+q8boUEZFuKdCP44Lx2SQn+HVxVEQiggL9ONrnpD+/ejuNmpMuImFOgd6NWe1z0ldpTrqIhDcFejdKCtIYlTWY+UurvC5FROS4FOjdCMxJL+DtzXvYrDnpIhLGFOg90L5O+pNaJ11EwpgCvQdyUhM4Z3QW85dVaU66iIQtBXoPzZqUT1XdAZZoTrqIhCkFeg9d2D4nXcMuIhKmFOg9lBDr47LiXJ5ftYO9ByP2A5lEJIop0E/ArEn5HGhu1Zx0EQlLCvQTUFqQximZg7UUgIiEJQX6CTAzrp6Uz9sf7qGiRnPSRSS8KNBP0FVn5GEG85fpzlERCS8K9BM0LHUQ55yaqXXSRSTsKNB7YdakfCprD/DWh3u8LkVEpIMCvRcuOj2H5Hitky4i4UWB3gsJsT4uLc7lHyu38dqGnV6XIyICKNB77RsXnsbooUnc+lA5z6zc5nU5IiIK9N4akhTPo3OmUFKQxpcefZdH397idUkiMsB1G+hmdp+Z7TSz1cdpM8PMlpvZGjP7V2hLDF8pCbE89PmzOPe0LO58chX3/OsDr0sSkQGsJz30B4CLu9ppZmnA74DLnXOnA9eEpLIIMSjOx7zry7isOJe7n1/P3c+vxzlNZxSR/ufvroFzbqGZFR6nyWeBJ51zW4LtB9xVwjh/DL/4dAkpCX7u+dcHNDQ188OZRfhizOvSRGQA6TbQe+A0INbMXgOSgV865x46VkMzmwPMARg+fHgITh0+fDHGXVcUkZYYy29f/YCGA8387FMlxPl1mUJE+kcoAt0PTALOBwYBi81siXPuvaMbOufmAfMAysrKom5cwsz45kVjSR0Uy/95bj17D7bw++smMSjO53VpIjIAhKL7WAm86Jzb55zbDSwEikNw3Ig1Z/oo7r5qAgvf28X1f3qL+gPNXpckIgNAKAL978A5ZuY3s0TgLGBdCI4b0T5z5nB+89kzWFFZx2fm
LWFX40GvSxKRKNeTaYuPAouBMWZWaWY3m9lcM5sL4JxbB7wArATeBu51znU5xXEg+eSEYfzpxsls3r2Pa+5ZRGXtfq9LEpEoZl5NsSsrK3Pl5eWenLu/La2oZfb9b5MY5+fhW87k1KHJXpckIhHKzJY658qOtU9TMPrBpBHpPHbbVFraHJ/6wxJWVdZ7XZKIRCEFej8ZNyyFJ+ZOJTHOx7V/XMLiD2q8LklEoowCvR8VZg7mibnTGJaawI33v82CtdVelyQiUUSB3s9yUhN4/LapjMtJ5raHl/LUu1pTXURCQ4HugfTBcfzl1imcNTKDrz62ggcXbfa6JBGJAgp0jyTF+7nvpslcMD6b7z29hl/9830t6iUiJ0WB7qGEWB+/v+4Mrjojj5+9/B53PbtOHzwtIr0WirVc5CT4fTH8dFYxKQmx/OmND6k/0MzdV03A79P/tSJyYhToYSAmxvjeZeNJT4zj5wveo7GpmV9+ppSEWC3qJSI9p25gmDAzvvyJ0XzvsvG8uKaamx98h70HW7wuS0QiiAI9zMw+eyT/c00xSzbt4bp736J23yGvSxKRCKFAD0NXT8rn99edwbrtDXx63mKqG5q8LklEIoACPUxdeHoOD8yeTFXtAWbds4iKmn1elyQiYU6BHsamjcrkkVunsLephVn3LGb9jgavSxKRMKZAD3PFBWk8fttUYgw+/YclLNtS63VJIhKmFOgRYHR2Mk/MnUZaYizX/fEtXn9/l9cliUgYUqBHiIKMRP7f3KmMGJLI5x94h+dXbfe6JBEJMwr0CDI0OYHH5kxlYn4atz+yjMff2ep1SSISRhToESY1MZY/33wm54zO4o75K/njwk1elyQiYUKBHoES4/zce0MZ/zZhGP/93Dp+8I81uqtURLSWS6SK88fwq2tLGZIUx/1vbuYfK7bxlU+cxmcmF2hhL5EBSj/5EcwXY/zXzCL+dvvZnJKZxHf/tpqLf/k6/1xXrbXVRQYgBXoUKClI47HbpvCH6yfR2ua4+cFyrrv3LVZX1Xtdmoj0o24D3czuM7OdZra6m3aTzazFzGaFrjzpKTPjotNzeOmr0/nB5aezbnsDl/3mDb72+HK21R3wujwR6Qc96aE/AFx8vAZm5gP+L/BSCGqSkxDri+HGaYW89s3zmDP9FJ5ZuZ3zfvoaP3lxPY1NzV6XJyJ9qNtAd84tBPZ00+xLwHxgZyiKkpOXOiiWOy8Zxz+/di4XF+Xw21c/4LyfvsbDSypoaW3zujwR6QMnPYZuZnnAlcDve9B2jpmVm1n5rl26fb0/FGQk8svPlPL3ThdOL/rFQl04FYlCobgo+gvgW865brt9zrl5zrky51xZVlZWCE4tPVXc6cJpm4ObHyzns3/UhVORaBKKeehlwF/NDCAT+KSZtTjn/haCY0sItV84/fjYoTzy1hZ+seA9Lv31G1xVmsc3LhpDbtogr0sUkZNw0oHunBvZ/tzMHgCeUZiHt/YLp1eU5vG71zZy/5ubeXbVdm752EjmnjuK5IRYr0sUkV7oybTFR4HFwBgzqzSzm81srpnN7fvypC/pwqlIdDGvLoyVlZW58vJyT84tx7Ziax3//ew63t68h1FZg/nOJ8fx8bFDCQ6niUgYMLOlzrmyY+3TnaLSQRdORSKbAl2OcPQdp+t3NHDpr9/ga4/pjlORcKchFzmuhqZmfvfqB9z35ocY6MKpiMc05CK9lpIQy7cvGcsrXz984XTGT17jz7pwKhJ2FOjSI/nph+84HTU0if8I3nG6YK3uOBUJFwp0OSHFBWk8NmcK866fhHNwy0O6cCoSLjSGLr3W3NrGo29v4RcL3mfPvkNcfHoON04rZMopGZrqKNJHjjeGrkCXk9bQ1My8f23i4bcqqNvfzGnZSdwwtZArS/MYHK9PORQJJQW69Ium5laeXrGNBxdtZs22BpLj/cwqy+f6KSM4JSvJ6/JEooICXfqVc45lW+p4aPFmnlu1neZWx/TTsrhx
6ghmjBmKL0bDMSK9pUAXz+xsbOKvb2/lL29VUN1wkIKMQVw/ZQSfKisgLTHO6/JEIo4CXTzX3NrGS2uqeXDxZt7+cA/x/hhmluRyw9RCivJSvS5PJGIo0CWsrNvewEOLK/jbu1UcaG5l0oh0bpg6gkuKhhHn10xakeNRoEtYqt/fzP9bupU/L6mgomY/mUnxfPas4Vx31nCyUxK8Lk8kLCnQJay1tTkWvr+LhxZX8OqGnfjMuKgohxunFjK5MF1z2kU6OV6ga5KweC4mxpgxZigzxgylomYfDy+p4LF3tvLsyu2MG5bCjVNHMLMkj0FxPq9LFQlr6qFLWDpwqJW/L6/igUWbWb+jkZQEP58qK+D6qSMYMWSw1+WJeEZDLhKxnHOUV9Ty4KLNvLB6B63OMeO0LG6YVsi5o7OI0Zx2GWA05CIRy8yYXJjB5MIMqhuaeOStLTzy9hZm3/8OhUMS+dyUEVxTVkDqIK3PLqIeukScQy1tvLBmBw8t2kx5RS2DYn1cUZrHDVNHMG5YitflifQpDblI1FpdVc+fF1fwt+VVHGxp48yRGVxZmsf544YyNFlTHyX6KNAl6tXtP8Tj5Vv5y1tbqKjZjxmUFKRxwfhsLhyfw6lDtTiYRIeTCnQzuw+4FNjpnCs6xv7rgG8BBjQC/8s5t6K7ohTo0hecc6zf0cjLa6t5eW01q4IfvHFK5mAuGJ/NBeOzKR2ergXCJGKdbKBPB/YCD3UR6NOAdc65WjO7BPi+c+6s7opSoEt/2F5/gAVrq3lpbTVLNtXQ3OoYMjiO88cN5YLxOXxsdCYJsZrfLpHjpIdczKwQeOZYgX5Uu3RgtXMur7tjKtClvzU0NfPahl28vLaa19bvpPFgC4NifXxsdCYXjM/m/HHZZAzWCpAS3vpz2uLNwPMhPqZISKQkxHJ5cS6XF+dyqKWNtz6s6RiaeWltNTEGZSMyuPD0wNCMbmCSSBOyHrqZnQf8DjjHOVfTRZs5wByA4cOHT6qoqOhNzSIh5ZxjdVUDL6/dwUtrq1m/oxGA07KTguPuOUzMS9VNTBIW+nzIxcwmAk8Blzjn3utJURpykXC1dc/+jp7725v30NrmyE6J5xPjAj33qaOGEO/XuLt4o08D3cyGA68ANzjnFvW0KAW6RIK6/Yd4Zf1OXl5bzb/e28X+Q60kxfs597QsLhifzXljhpKaqLtUpf+c7CyXR4EZQCZQDXwPiAVwzt1jZvcCVwPt4yctXZ2sMwW6RJqm5lYWf1DDS2urWbCuml2NB/HHGGeOzODC8dl8Ynw2+emJXpcpUU43FomEWFubY3llXcfQzMadewEYPyylY777+GEpGneXkFOgi/SxD3fv4+W1O3h5bTXlFbU4B0MGxzHt1EzOHjWEs0/NpCBDvXc5eQp0kX5Us/cgr27YxaKNu3lj4252Nh4EoCBjEOecmsm0UZlMGzWEIUnxHlcqkUiBLuIR5xwf7NrLmxtreGPjbpZsqqGxqQWAccNSAr330ZmcWZjB4HitZi3dU6CLhImW1jZWVdWz6IMa3ty4m/KKWg61tOGPMUqHp3H2qZmcfWomJQVpxPpivC5XwpACXSRMNTW3Ur65ljc/2M2ijbtZWVWPc5AY5+OskRmcHRyiGZuTrAusAugTi0TCVkKsj3NGZ3LO6EwA6vc3s3hTDYs+CIy/v/rsOiBwgXVq8OLqObrAKl1QoIuEkdTEWC4uyuHiohwgsFrkoo2B4Zk3P9jNMyu3A4ELrGePygz24HWBVQI05CISIQIXWPcFwn3jbhZ3usA6NieZc4Lj72eO1AXWaKYxdJEo1NLaxuptDby5cTeLPtjNO5uPvMDaPj1yYn4ag+K09ky0UKCLDABNza0srajt6MGvqqqnzYEvxhibk0xJQRolBWmUDk/nlMzBusgaoRToIgNQ/YFmyjfvYfnWOt7dUseKrXU0HgwM0SQn+DsFfBolBen6cI8IoVkuIgNQ6qBYzh8X+CQmCKw/s2n3
XpZtqWP51jqWb6njt69upC3Ypxuekdgp4NMYn5uiZYIjjHroIgPY/kMtrKqsDwR8sCe/o6EJgDhfDONyUyjtFPLDMxIx01CNlzTkIiI9tqO+ieVba3k3GPCrKus50NwKQMbguI6hmpKCNIoL0kgdpPXg+5OGXESkx3JSE7g4dRgXFw0DArNp3qvey7tba1keHK55dcNO2vuCo7IGU1KQTsnwNEoL0hibk4xfyxZ4Qj10ETlhDU3NrNxaH+jJB0O+Zt8hABJiY5iQl9oxo6akII1hqQkaqgkRDbmISJ9yzlFZeyA4TFPL8q11rKlq4FBrGwBDk+OZmJ/KxPw0JuSnUpyfplk1vaQhFxHpU2ZGQUYiBRmJXF6cC8ChljbWbW/g3S21rKisZ2VlHf9cf3ioJj99UEfIT8xPpSgvlZQEjcefDAW6iPSJOH8MxcELp+0am5pZXdXAyso6VlbWs7KqjudW7ejYf0rWYCbmBUK+uCCV8cNSdZfrCVCgi0i/SU6IZeqoIUwdNaRj2559h1hVVc/KrXWsrKpn8aYa/rZ8GxC4y3X00KSOnnxxfhpjcpKJ8+ui67FoDF1Ewk51Q1OgB9/ek6+so3Z/MxCcHz8smQmdQv7UoUn4BshSBrooKiIRrf2ia+eQX1VVz97gUgaDYn0U5aUwIS8wVDMxP40RGYlRuV6NAl1Eok5bm+PDmn2srKxjxdZAwK/ZVk9Tc2BmTXKCn4n5qYGQz09lYkEauVEwffKkZrmY2X3ApcBO51zRMfYb8Evgk8B+4Cbn3LKTK1lE5PhiYoxRWUmMykriytJ8IHAT1Ps793YaqqnnT29sork10HFNS4zltOxkxuYkd3wdnZ0cNXe79uSi6APAb4CHuth/CTA6+DgL+H3wq4hIv/L7Yhg3LIVxw1L49OTAtoMtrazf3sjKyjrWbm/kvepGnlpW1bHyJMCw1ATG5CQzJjsQ9GNykjl1aBIJsZE1w6bbQHfOLTSzwuM0mQk85AJjN0vMLM3MhjnntoeqSBGR3or3+z4yfdI5x7b6Jt7b0cj6HYGQX7+jkUUbazpuhooxKMwc3BHyY3OSOS0nmREZiWG7tEEopi3mAVs7va4MbvtIoJvZHGAOwPDhw0NwahGRE2dm5KUNIi9tEOeNHdqxvaW1jc01+9mwo5EN1Y1s2NHA+h2NvLBmR8cNUXH+GEYPTQoEfU5yR88+HJY36Nd56M65ecA8CFwU7c9zi4REWxu0Hjry0XKw0/NO2znOP/EeTUbopk1Pj9HWBq790Xr4eVtr4Bgf2dbezh1jW3u7ro7XduSj87bOf6aO2rt6fYz3oNvv6enrro/rB04FTnWOfwOIAwocrXmOvU0t7D3YQmNTM3sPtLB3fTMHV7fRjGM1sMFnJMX7GRzvJzneF/zqJ9ZnHz3/+JlQeh2hFopArwIKOr3OD26TUHIOWpuhpSkYHE2BIGk5CG3N0NYS/CFrDT5vDT5v7fS85fAP2BFtjrG9rSW4r63T8/btbUe1CT7MAAt8tZjgo/PzmOD+o7d31abzvh62wQJ1fSRgDx4Zth1B3BzY19ocfN35efu+9uMcDBxbjtT578V8ga8xvk5/L77DfzcQfE4vXnfedvTrEzxml8c9xjbAZ0YqkIoFQj4OSDFa2xwHW1ppanEcbA58bWpopaHNUR/8fn+MER/rIyHW1/HVv7+evljJJhSB/jTwRTP7K4GLofVRPX7e1gpN9XCgFg42HA7V9jDoeN05eA8dO4iP26b9dadjdtdj60vmgxh/8AfVF/ja8Ty4HYK9MffRnhruGPu6aBf64sEfD7548MeBr9PDHxfY7ouDuETwpQW3xx/VLv6o74nr4njx4IsNPLfuxll78Ot5t7/C9+AYHeHaOWw7B7AdY1vndt18b4RPAzwZPiAx+GjnnKO64SDrdzTwXnUjG3bsZUN1A+9X7+VgS+A3lc+PGMl/9kE9PZm2+CgwA8g0s0rge0BssPB7gOcITFncSGDa
4uw+qDP0Wg4GQrnjUXfU61poOta2BnoVOjF+8CcEf+gTAkFwxOt4SMwIhkR8F23aX7e3CT5iYg+H6hGB6w8+j+n0vH17zFFtjrO9o/fbTz4S9F39B9FFuxj/kUHt0woX0n/MjJzUBHJSE5gx5vD4fGubo6JmH+9VN5KfnnicI/ReT2a5XNvNfgfcHrKKToRzcLCx6/A9Iqjrjgzq5v1dH9diICENBqUHHomZMOTUw6/bH/HJh8O4u7Bu78FK99p7hOg9k+jhizFOyUrilKykPjtH5HVd3n8ZXrjzcDi71q7b+hOODOD0QsgthUFpnbanfTSo45IDvVoRkQgSeYE+KB1yio4M4M696c5BHTvI62pFRPpN5AV6fhlc84DXVYiIhB2NK4iIRAkFuohIlFCgi4hECQW6iEiUUKCLiEQJBbqISJRQoIuIRAkFuohIlPDsQ6LNbBdQ0ctvzwR2h7CcSKf340h6Pw7Te3GkaHg/Rjjnso61w7NAPxlmVt7Vp14PRHo/jqT34zC9F0eK9vdDQy4iIlFCgS4iEiUiNdDneV1AmNH7cSS9H4fpvThSVL8fETmGLiIiHxWpPXQRETmKAl1EJEpEXKCb2cVmtsHMNprZt72ux0tmVmBmr5rZWjNbY2Zf9romr5mZz8zeNbNnvK7Fa2aWZmZPmNl6M1tnZlO9rskrZvbV4M/IajN71MwSvK6pL0RUoJuZD/gtcAkwHrjWzMZ7W5WnWoCvO+fGA1OA2wf4+wHwZWCd10WEiV8CLzjnxgLFDND3xczygP8NlDnnigh8+vhnvK2qb0RUoANnAhudc5ucc4eAvwIzPa7JM8657c65ZcHnjQR+YPO8rco7ZpYP/Btwr9e1eM3MUoHpwJ8AnHOHnHN1nhblLT8wyMz8QCKwzeN6+kSkBXoesLXT60oGcIB1ZmaFQCnwlseleOkXwB1Am8d1hIORwC7g/uAQ1L1mNtjrorzgnKsCfgpsAbYD9c65l7ytqm9EWqDLMZhZEjAf+IpzrsHrerxgZpcCO51zS72uJUz4gTOA3zvnSoF9wIC85mRm6QR+kx8J5AKDzexz3lbVNyIt0KuAgk6v84PbBiwziyUQ5n9xzj3pdT0eOhu43Mw2ExiK+7iZPextSZ6qBCqdc+2/sT1BIOAHok8AHzrndjnnmoEngWke19QnIi3Q3wFGm9lIM4sjcGHjaY9r8oyZGYEx0nXOuZ95XY+XnHN3OufynXOFBP5dvOKci8peWE8453YAW81sTHDT+cBaD0vy0hZgipklBn9mzidKLxD7vS7gRDjnWszsi8CLBK5U3+ecW+NxWV46G7geWGVmy4PbvuOce867kiSMfAn4S7DzswmY7XE9nnDOvWVmTwDLCMwMe5coXQJAt/6LiESJSBtyERGRLijQRUSihAJdRCRKKNBFRKKEAl1EJEoo0EVEooQCXUQkSvx/EM2nyVyyx6cAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9886\n"
     ]
    }
   ],
   "source": [
    "# 导入必要的库\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "from tqdm import * # tqdm用于显示进度条并评估任务时间开销\n",
    "import numpy as np\n",
    "import sys\n",
    "\n",
    "# Fix the random seed\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Model, optimizer and loss function\n",
    "model = LeNet()\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.02)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# Data transform: only convert the images to tensors\n",
    "transform = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "# MNIST training split and its shuffled loader\n",
    "train_dataset = datasets.MNIST(root='../data/mnist/', train=True, download=True, transform=transform)\n",
    "train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)\n",
    "# MNIST test split and its loader (fixed order)\n",
    "test_dataset = datasets.MNIST(root='../data/mnist/', train=False, download=True, transform=transform)\n",
    "test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)\n",
    "\n",
    "# Number of epochs and per-epoch history\n",
    "num_epochs = 10\n",
    "loss_history = []  # log10 of the summed per-batch training loss, one entry per epoch\n",
    "acc_history = []   # test-set accuracy, one entry per epoch\n",
    "\n",
    "# tqdm shows a progress bar and the time spent so far\n",
    "for epoch in tqdm(range(num_epochs), file=sys.stdout):\n",
    "\n",
    "    # Training pass: forward, loss, backward, parameter update for every batch\n",
    "    model.train()\n",
    "    batch_losses = []\n",
    "    for inputs, labels in train_loader:\n",
    "        optimizer.zero_grad()\n",
    "        loss = criterion(model(inputs), labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        batch_losses.append(loss.item())\n",
    "    total_loss = sum(batch_losses)\n",
    "\n",
    "    # Evaluation pass on the test set, without gradient tracking\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        total_correct = sum(\n",
    "            (model(inputs).argmax(1) == labels).sum().item()\n",
    "            for inputs, labels in test_loader\n",
    "        )\n",
    "\n",
    "    # Store the training loss (log scale, since the raw sum can be large)\n",
    "    # and the test accuracy for this epoch\n",
    "    loss_history.append(np.log10(total_loss))\n",
    "    acc_history.append(total_correct / len(test_dataset))\n",
    "\n",
    "    # Report intermediate values on every second epoch\n",
    "    if epoch % 2 == 0:\n",
    "        tqdm.write(\"Epoch: {0} Loss: {1} Acc: {2}\".format(epoch, loss_history[-1], acc_history[-1]))\n",
    "\n",
    "# Plot the loss and accuracy curves with Matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "for curve, curve_label in ((loss_history, 'loss'), (acc_history, 'accuracy')):\n",
    "    plt.plot(curve, label=curve_label)\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "# Print the final accuracy\n",
    "print(\"Accuracy:\", acc_history[-1])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
